{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I1201 14:22:26.198628 139995368892224 file_utils.py:39] PyTorch version 0.4.1 available.\n",
      "I1201 14:22:27.596594 139995368892224 configuration_utils.py:151] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /root/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.bf3b9ea126d8c0001ee8a1e8b92229871d06d36d8808208cc2449280da87785c\n",
      "I1201 14:22:27.598366 139995368892224 configuration_utils.py:168] Model config {\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"finetuning_task\": \"sst-2\",\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"num_labels\": 1,\n",
      "  \"output_attentions\": false,\n",
      "  \"output_hidden_states\": false,\n",
      "  \"pruned_heads\": {},\n",
      "  \"torchscript\": false,\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"use_bfloat16\": false,\n",
      "  \"vocab_size\": 30522\n",
      "}\n",
      "\n",
      "I1201 14:22:28.443245 139995368892224 tokenization_utils.py:373] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /root/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.insert(0,'./data')\n",
    "\n",
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
    "\n",
    "import torch.backends.cudnn as cudnn\n",
    "import torch.optim\n",
    "import torch.utils.data\n",
    "# import torchvision.transforms as transforms\n",
    "#from data.load_datasets import *\n",
    "from load_datasets_final import CaptionDataset\n",
    "from experiment.utils import *\n",
    "from nltk.translate.bleu_score import corpus_bleu\n",
    "import torch.nn.functional as F\n",
    "from tqdm import tqdm\n",
    "from nlgeval import NLGEval\n",
    "from transformers import (WEIGHTS_NAME, BertConfig,\n",
    "                                  BertForSequenceClassification, BertTokenizer,\n",
    "                                  )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/torch/serialization.py:425: SourceChangeWarning: source code of class 'models.decoder_with_attention_final.DecoderWithAttention' has changed. Tried to save a patch, but couldn't create a writable file DecoderWithAttention.patch. Make sure it doesn't exist and your working directory is writable.\n",
      "  warnings.warn(msg, SourceChangeWarning)\n",
      "/usr/local/lib/python3.6/dist-packages/torch/serialization.py:425: SourceChangeWarning: source code of class 'transformers.modeling_bert.BertModel' has changed. Tried to save a patch, but couldn't create a writable file BertModel.patch. Make sure it doesn't exist and your working directory is writable.\n",
      "  warnings.warn(msg, SourceChangeWarning)\n",
      "I1201 14:22:32.531902 139995368892224 tokenization_utils.py:373] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /root/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084\n"
     ]
    }
   ],
   "source": [
    "# Parameters\n",
    "data_folder = 'preprocessed_data'  # folder with data files saved by create_input_files.py\n",
    "data_name = 'preprocessed_coco'  # base name shared by data files\n",
    "#'ckpt/BERT_3.pth.tar' \n",
    "checkpoint_file = 'ckpt/BERT_3.pth.tar'\n",
    "#checkpoint_file = 'BEST_34checkpoint_coco_5_cap_per_img_5_min_word_freq.pth.tar'  # model checkpoint\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")  # sets device for model and PyTorch tensors\n",
    "cudnn.benchmark = True  # set to true only if inputs to model are fixed size; otherwise lot of computational overhead\n",
    "\n",
    "# Load model\n",
    "torch.nn.Module.dump_patches = True\n",
    "checkpoint = torch.load(checkpoint_file, map_location = device)\n",
    "decoder = checkpoint['decoder']\n",
    "decoder = decoder.to(device)\n",
    "decoder.eval()\n",
    "\n",
    "nlgeval = NLGEval()  # loads the evaluator\n",
    "\n",
    "#BERT Tokenizer\n",
    "model_name_or_path = \"bert-base-uncased\" \n",
    "tokenizer_class = BertTokenizer\n",
    "tokenizer = tokenizer_class.from_pretrained(model_name_or_path, do_lower_case = True)\n",
    "\n",
    "BERT_VOCA_SIZE = 30522\n",
    "vocab_size = BERT_VOCA_SIZE #len(word_map)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EVALUATING AT BEAM SIZE 5:   1%|▏         | 341/25000 [00:59<1:10:57,  5.79it/s]Process Process-1:\n",
      "Traceback (most recent call last):\n",
      "  File \"/usr/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n",
      "    self.run()\n",
      "  File \"/usr/lib/python3.6/multiprocessing/process.py\", line 93, in run\n",
      "    self._target(*self._args, **self._kwargs)\n",
      "  File \"/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py\", line 96, in _worker_loop\n",
      "    r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)\n",
      "  File \"/usr/lib/python3.6/multiprocessing/queues.py\", line 104, in get\n",
      "    if not self._poll(timeout):\n",
      "  File \"/usr/lib/python3.6/multiprocessing/connection.py\", line 257, in poll\n",
      "    return self._poll(timeout)\n",
      "  File \"/usr/lib/python3.6/multiprocessing/connection.py\", line 414, in _poll\n",
      "    r = wait([self], timeout)\n",
      "  File \"/usr/lib/python3.6/multiprocessing/connection.py\", line 911, in wait\n",
      "    ready = selector.select(timeout)\n",
      "  File \"/usr/lib/python3.6/selectors.py\", line 376, in select\n",
      "    fd_event_list = self._poll.poll(timeout)\n",
      "KeyboardInterrupt\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-3-e2eecb7e25d2>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     85\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     86\u001b[0m         \u001b[0;31m# Add new words to sequences\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 87\u001b[0;31m         \u001b[0mseqs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mseqs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mprev_word_inds\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnext_word_inds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# (s, step+1)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     88\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     89\u001b[0m         \u001b[0;31m# Which sequences are incomplete (didn't reach <end>)?\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Evaluation\n",
    ":param beam_size: beam size at which to generate captions for evaluation\n",
    ":return: Official MSCOCO evaluator scores - bleu4, cider, rouge, meteor\n",
    "\"\"\"\n",
    "# DataLoader\n",
    "loader = torch.utils.data.DataLoader(\n",
    "    CaptionDataset(data_folder, data_name, 'TEST'),\n",
    "    batch_size=1, shuffle=False, num_workers=1, pin_memory=torch.cuda.is_available())\n",
    "\n",
    "# Lists to store references (true captions), and hypothesis (prediction) for each image\n",
    "# If for n images, we have n hypotheses, and references a, b, c... for each image, we need -\n",
    "# references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]\n",
    "references = list()\n",
    "hypotheses = list()\n",
    "\n",
    "CLS_IDX = 101\n",
    "SEP_IDX = 102\n",
    "PAD_IDX = 0\n",
    "start_idx = CLS_IDX\n",
    "end_idx = SEP_IDX\n",
    "beam_size = 5\n",
    "\n",
    "# For each image\n",
    "for i, (image_features, caps, caplens, allcaps) in enumerate(\n",
    "        tqdm(loader, desc=\"EVALUATING AT BEAM SIZE \" + str(beam_size))):\n",
    "    k = beam_size\n",
    "\n",
    "    # Move to GPU device, if available\n",
    "    image_features = image_features.to(device)  # (1, 3, 256, 256)\n",
    "    image_features_mean = image_features.mean(1)\n",
    "    image_features_mean = image_features_mean.expand(k,2048)\n",
    "\n",
    "    # Tensor to store top k previous words at each step; now they're just <start>\n",
    "    CLS_IDX = 101\n",
    "    start_idx = CLS_IDX\n",
    "    k_prev_words = torch.LongTensor([[start_idx]] * k).to(device)  # (k, 1)\n",
    "\n",
    "    # Tensor to store top k sequences; now they're just <start>\n",
    "    seqs = k_prev_words  # (k, 1)\n",
    "\n",
    "    # Tensor to store top k sequences' scores; now they're just 0\n",
    "    top_k_scores = torch.zeros(k, 1).to(device)  # (k, 1)\n",
    "\n",
    "    # Lists to store completed sequences and scores\n",
    "    complete_seqs = list()\n",
    "    complete_seqs_scores = list()\n",
    "\n",
    "    # Start decoding\n",
    "    step = 1\n",
    "    h1, c1 = decoder.init_hidden_state(k)  # (batch_size, decoder_dim)\n",
    "    h2, c2 = decoder.init_hidden_state(k)\n",
    "\n",
    "    # s is a number less than or equal to k, because sequences are removed from this process once they hit <end>\n",
    "    while True:\n",
    "        #tensor((tuple))\n",
    "        \n",
    "        embeddings = decoder.embedding(k_prev_words)\n",
    "        \n",
    "        #print(embeddings[0].size())\n",
    "        embeddings = embeddings.squeeze(1)  # (s, embed_dim)\n",
    "        h1,c1 = decoder.top_down_attention(\n",
    "            torch.cat([h2,image_features_mean,embeddings], dim=1),\n",
    "            (h1,c1))  # (batch_size_t, decoder_dim)\n",
    "        attention_weighted_encoding = decoder.attention(image_features,h1)\n",
    "        h2,c2 = decoder.language_model(\n",
    "            torch.cat([attention_weighted_encoding,h1], dim=1),(h2,c2))\n",
    "\n",
    "        scores = decoder.fc(h2)  # (s, vocab_size)\n",
    "        scores = F.log_softmax(scores, dim=1)\n",
    "\n",
    "        # Add\n",
    "        scores = top_k_scores.expand_as(scores) + scores  # (s, vocab_size)\n",
    "\n",
    "        # For the first step, all k points will have the same scores (since same k previous words, h, c)\n",
    "        if step == 1:\n",
    "            top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)  # (s)\n",
    "        else:\n",
    "            # Unroll and find top scores, and their unrolled indices\n",
    "            top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)  # (s)\n",
    "\n",
    "        # Convert unrolled indices to actual indices of scores\n",
    "        prev_word_inds = top_k_words / vocab_size  # (s)\n",
    "        next_word_inds = top_k_words % vocab_size  # (s)\n",
    "\n",
    "        # Add new words to sequences\n",
    "        seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)  # (s, step+1)\n",
    "\n",
    "        # Which sequences are incomplete (didn't reach <end>)?\n",
    "        incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if\n",
    "                           next_word != end_idx]\n",
    "        complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))\n",
    "\n",
    "        # Set aside complete sequences\n",
    "        if len(complete_inds) > 0:\n",
    "            complete_seqs.extend(seqs[complete_inds].tolist())\n",
    "            complete_seqs_scores.extend(top_k_scores[complete_inds])\n",
    "        k -= len(complete_inds)  # reduce beam length accordingly\n",
    "\n",
    "        # Proceed with incomplete sequences\n",
    "        if k == 0:\n",
    "            break\n",
    "        seqs = seqs[incomplete_inds]\n",
    "        h1 = h1[prev_word_inds[incomplete_inds]]\n",
    "        c1 = c1[prev_word_inds[incomplete_inds]]\n",
    "        h2 = h2[prev_word_inds[incomplete_inds]]\n",
    "        c2 = c2[prev_word_inds[incomplete_inds]]\n",
    "        image_features_mean = image_features_mean[prev_word_inds[incomplete_inds]]\n",
    "        top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)\n",
    "        k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)\n",
    "\n",
    "        # Break if things have been going on too long\n",
    "        if step > 50:\n",
    "            break\n",
    "        step += 1\n",
    "\n",
    "    i = complete_seqs_scores.index(max(complete_seqs_scores))\n",
    "    seq = complete_seqs[i]\n",
    "\n",
    "# CLS_IDX = 101\n",
    "# SEP_IDX = 102\n",
    "# PAD_IDX = 0\n",
    "# start_idx = CLS_IDX\n",
    "# end_idx = SEP_IDX\n",
    "#tokenizer._convert_id_to_token(102)\n",
    "    # References\n",
    "    #원래꺼가 [[]]이런식으로 되어있었나보네, 변경 > [[],[],[],[]]\n",
    "    img_caps = allcaps[0][0].tolist()\n",
    "    #print(len(img_caps))\n",
    "    #print(img_caps[0][0])\n",
    "    img_captions = list(\n",
    "        map(lambda c: [tokenizer._convert_id_to_token(w) for w in c if w not in {start_idx, end_idx, PAD_IDX}],\n",
    "            img_caps))  # remove <start> and pads\n",
    "    img_caps = [' '.join(c) for c in img_captions]\n",
    "    #print(img_caps)\n",
    "    references.append(img_caps)\n",
    "\n",
    "    # Hypotheses\n",
    "    hypothesis = ([tokenizer._convert_id_to_token(w) for w in seq if w not in {start_idx, end_idx, PAD_IDX}])\n",
    "    hypothesis = ' '.join(hypothesis)\n",
    "    #print(hypothesis)\n",
    "    hypotheses.append(hypothesis)\n",
    "    assert len(references) == len(hypotheses)\n",
    "\n",
    "# Calculate scores\n",
    "metrics_dict = nlgeval.compute_metrics(references, hypotheses)\n",
    "print(metrics_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
